#include <stdio.h>

#include <cmath>
#include <iostream>
#include <set>
#include <stdexcept>

#include <openbabel/mol.h>
#include <openbabel/obconversion.h>

using namespace std;
using namespace OpenBabel;

//------------------------------------------------------------------------------
// Atom classes used for the tleap output (assigned by GetAtomType):
//   SI0 - inner silicon: no bonded OY1 oxygen
//   SI1 - outer silicon: bonded to at least one OY1 oxygen
//   OX  - oxygen bonded to exactly two silicons (inner/bridging oxygen)
//   OY0 - any other oxygen (e.g. an outer oxygen that carries a hydrogen)
//   OY1 - non-OX oxygen with exactly one bond (outer oxygen, no hydrogen)
//   HY  - hydrogen
enum atomType { SI0, SI1, OX, OY0, OY1, HY };
// Which end of the structure along z holds the fixed silicons:
// TOP = largest silicon z, BOTTOM = smallest silicon z (see GetFixedZValue).
enum fixType {TOP, BOTTOM};

// function prototypes
void AddHydrogens(OBMol& mol);
void AlignStructure(OBMol& mol);
bool BuildTLeap(OBMol& probemol, OBMol& bulkmol, const string& name);
void CleanStructure(OBMol& mol);
int  CountFreeOxygens(OBMol& mol);
int  CountNbrSilicons(OBAtom* p_atom);
int  CountOxygens(OBMol& mol);
void CreateAtom(FILE* file, OBAtom* p_atom);
void CutStructure(OBMol& imol, OBMol& omol, double coneAngle, double coneRadius);
double GetFixedZValue(OBMol& mol, fixType fType);
string GetSiName(OBAtom* p_atom);
// NOTE(review): no definition of IsSi0 is visible in this file; the function
// actually defined (and used by GetAtomType) is isSi0 (lowercase) -- confirm
// and unify the names.
bool IsSi0(OBAtom* p_atom);
// With the default tolerance (0) the z test is skipped and only the stored
// "fixed" flag (isotope == 1) is checked.
bool IsFixedAtom(OBAtom* p_atom, double refZ=0, double tolerance=0);
void InitFixedAtoms(OBMol& mol, fixType fType, double tolerance);
bool ReadStructure(OBMol& mol, const string& name);
void WriteBonds(FILE* file, OBMol& mol, const string &resName);
bool WriteStructure(OBMol& mol, const string& name);

atomType GetAtomType(OBAtom* p_atom);





//------------------------------------------------------------------------------

// Entry point.
//  1. Reads the probe and the bulk structure from the hard-coded xyz paths.
//  2. Adds hydrogens to their singly-bonded outer oxygens.
//  3. Marks the fixed atom layers: the bottom of the probe and the top of the bulk.
//  4. Writes the tleap input file test/ctrl.in.
// Returns 0 on success, 1 if an input file cannot be read or building the
// tleap file fails.
int main(int argc,char* argv[])
{
    /*
    OBMol imol; // input molecule
    OBMol omol; // output molecule


    string iname = "test/output.xyz";
    if( ReadStructure(imol,iname) == false ) return(1);*/

    OBMol probemol; // input probe molecule
    OBMol bulkmol;  // input bulk molecule

    string probename = "test/probe.xyz";
    string bulkname  = "test/block.xyz";

    if( ReadStructure(probemol,probename) == false ) return(1);
    if( ReadStructure(bulkmol,bulkname) == false ) return(1);

    /*
    int freeOx = CountFreeOxygens(imol);

    cout << "Number of atoms: " << imol.NumAtoms() << endl;
    cout << "Number of oxygens: " << CountOxygens(imol) << endl;
    cout << "Number of silicons: " << imol.NumAtoms() - CountOxygens(imol) << endl;
    cout << "Number of free oxygens: " << freeOx << endl;
    */

    // add hydrogens to probe and bulk
    AddHydrogens(probemol);
    AddHydrogens(bulkmol);

    cout << "Initializing fixed atoms for probe..." << endl;
    InitFixedAtoms(probemol, BOTTOM, 0.5);

    cout << "Initializing fixed atoms for bulk..." << endl;
    InitFixedAtoms(bulkmol, TOP,0.5);
    cout << "Building tleap..." << endl;

    // build the tleap input; a failure must be reflected in the exit status
    // instead of silently returning 0
    if( BuildTLeap(probemol, bulkmol, "test/ctrl.in") == false ) return(1);

    return(0);
}

//------------------------------------------------------------------------------

// Cuts a cone-shaped region out of imol and appends it to omol.
//
// coneAngle  - cone opening angle in radians (coneAngle/2 is used)
// coneRadius - radius of the cone cross-section at z = 0, in Angstrom
//
// Selection works in two passes:
//  1. A silicon at (x,y,z) is kept when its distance from the z axis is
//     smaller than coneRadius - z*sin(coneAngle/2).
//  2. Every oxygen bonded to a kept silicon is kept as well.
// Only element and position are copied; bonds in omol are re-perceived with
// ConnectTheDots(). Removing dangling fragments is left to CleanStructure().
//
// NOTE(review): for a geometric cone with half angle a, the radius changes by
// z*tan(a), not z*sin(a). sin() is kept to preserve existing results --
// confirm which one is intended.
void CutStructure(OBMol& imol, OBMol& omol, double coneAngle, double coneRadius)
{
    cout << endl;
    cout << "Cutting structure ..." << endl;

    // pass 1 - select silicon atoms inside the cone (babel indexes atoms from 1)
    vector<unsigned int> selected_atoms;
    const double slope = sin(coneAngle/2); // loop invariant

    for(unsigned int i=1; i <= imol.NumAtoms(); i++){
        OBAtom* p_atom = imol.GetAtom(i);
        if( p_atom->GetAtomicNum() != 14 ) continue; // only silicons

        double x = p_atom->GetX();
        double y = p_atom->GetY();
        double z = p_atom->GetZ();

        double r = coneRadius - z*slope;
        if( r > sqrt(x*x+y*y) ){
            selected_atoms.push_back(i);
        }
    }

    cout << "   Number of selected silicons : " << selected_atoms.size() << endl;

    // pass 2 - select oxygens bonded to the selected silicons; the set removes
    // duplicates (a bridging oxygen is reached from two silicons)
    set<unsigned int> selected_oxygens;

    for(size_t i=0; i < selected_atoms.size(); i++){
        OBAtom* p_atom = imol.GetAtom( selected_atoms[i] );

        OBBondIterator ni;
        for(OBAtom* p_natom = p_atom->BeginNbrAtom(ni); p_natom != NULL;
            p_natom = p_atom->NextNbrAtom(ni)){
            // filter by Z: a non-oxygen neighbour (e.g. a silicon already
            // selected in pass 1) would otherwise be copied a second time
            if( p_natom->GetAtomicNum() != 8 ) continue;
            selected_oxygens.insert(p_natom->GetIdx());
        }
    }

    // silicons first, then oxygens in ascending index order
    selected_atoms.insert(selected_atoms.end(),selected_oxygens.begin(),selected_oxygens.end());

    // copy element and position of every selected atom into omol
    omol.ReserveAtoms(selected_atoms.size());
    for(size_t i=0; i < selected_atoms.size(); i++){
        OBAtom* p_satom = imol.GetAtom( selected_atoms[i] );
        OBAtom* p_tatom = omol.NewAtom();
        p_tatom->SetAtomicNum(p_satom->GetAtomicNum()); // set Z
        p_tatom->SetVector(p_satom->GetVector());       // set coordinates
    }

    // create bonds
    omol.ConnectTheDots();
}

//------------------------------------------------------------------------------


// Removes weakly attached silicon units and the oxygens they leave behind.
//
// A silicon is deleted when at most one other atom is found through its
// neighbour-of-neighbour walk:
//   0 - isolated SiO4
//   1 - SiO4 attached to the rest by a single -O- bridge
// Afterwards every oxygen that has no bond left is deleted.
//
// Removing the end of a Si-O-Si-O- chain can expose the next link. The sweep
// is therefore repeated as long as a silicon with exactly one such neighbour
// was found.
//
// The count assumes that all neighbours of a silicon are oxygens and all
// neighbours of those oxygens are silicons.
void CleanStructure(OBMol& mol)
{
    bool repeat = false;

    do {
        cout << endl;
        cout << "Cleaning structure ..." << endl;

        repeat = false;

        // --- silicons with at most one neighbouring silicon ------------------
        vector<OBAtom*> doomed;
        for( unsigned int idx = 1; idx <= mol.NumAtoms(); ++idx ){
            OBAtom* p_si = mol.GetAtom(idx);
            if( p_si->GetAtomicNum() != 14 ) continue;

            // walk Si -> O -> X and count every X that is not p_si itself
            int siNeighbours = 0;
            OBBondIterator outer;
            for( OBAtom* p_ox = p_si->BeginNbrAtom(outer); p_ox != NULL; p_ox = p_si->NextNbrAtom(outer) ){
                OBBondIterator inner;
                for( OBAtom* p_si2 = p_ox->BeginNbrAtom(inner); p_si2 != NULL; p_si2 = p_ox->NextNbrAtom(inner) ){
                    if( p_si2 != p_si ) ++siNeighbours;
                }
            }

            if( siNeighbours <= 1 ) doomed.push_back(p_si);
            if( siNeighbours == 1 ) repeat = true; // chain end -> sweep again
        }

        cout << "   Number of orphaned silicons : " << doomed.size() << endl;
        for( size_t k = 0; k < doomed.size(); ++k ) mol.DeleteAtom(doomed[k]);

        // --- oxygens that lost all their bonds -------------------------------
        doomed.clear();
        for( unsigned int idx = 1; idx <= mol.NumAtoms(); ++idx ){
            OBAtom* p_ox = mol.GetAtom(idx);
            if( p_ox->GetAtomicNum() == 8 && p_ox->GetValence() == 0 )
                doomed.push_back(p_ox);
        }

        cout << "   Number of orphaned oxygens : " << doomed.size() << endl;
        for( size_t k = 0; k < doomed.size(); ++k ) mol.DeleteAtom(doomed[k]);

        if( repeat )
            cout << "   Detected Si-O chains, cleaning one more time ...";
    } while( repeat );
}

//------------------------------------------------------------------------------

// Reads a structure in plain xyz format and appends its atoms to mol.
//   line 1     - number of atoms (must be > 0)
//   line 2     - comment (ignored)
//   next lines - "<element symbol> <x> <y> <z>", one atom per line
// Bonds are perceived with ConnectTheDots(); bond orders are deliberately not
// perceived (see the note at the end).
// This hand-written reader avoids obscure behaviour of the standard OpenBabel
// xyz reader.
//
// Returns false, with a message on cerr, if the file cannot be opened or is
// malformed. A symbol the element table does not recognise (atomic number 0)
// only produces a warning; that atom keeps atomic number 0.
bool ReadStructure(OBMol& mol, const string& name)
{
    ifstream ifs(name.c_str());
    if( ! ifs ){
        cerr << ">>> ERROR: Unable to open input file : '" << name << "'!" << endl;
        return(false);
    }

    // read number of atoms
    int numofatoms = 0;
    ifs >> numofatoms;
    if( (! ifs) || (numofatoms <= 0) ){
        cerr << ">>> ERROR: Unable to read number of atoms or illegal number of atoms!" << endl;
        return(false);
    }

    // skip the rest of the count line and the comment line
    string buffer;
    getline(ifs,buffer);
    getline(ifs,buffer);

    // reserve storage for atoms
    mol.ReserveAtoms(numofatoms);

    // read atoms
    for(int i=0; i < numofatoms; i++){
        string symbol;
        double x,y,z;

        ifs >> symbol >> x >> y >> z;
        if( ! ifs ){
            cerr << ">>> ERROR: Unable to read atom " << i + 1 << "!" << endl;
            return(false);
        }

        int atomicNum = etab.GetAtomicNum(symbol.c_str());
        if( atomicNum == 0 ){
            // a typo in the symbol would otherwise silently become a dummy atom
            std::cerr << ">>> WARNING: Unknown element symbol '" << symbol
                      << "' for atom " << i + 1 << " in '" << name << "'!" << std::endl;
        }

        OBAtom* p_atom  = mol.NewAtom();
        p_atom->SetAtomicNum(atomicNum); // set Z
        p_atom->SetVector(x,y,z);        // set coordinates
    }

    // create bonds
    mol.ConnectTheDots();

    // OK, this causes troubles in standard xyz parser
    // mol.PerceiveBondOrders();

    return(true);
}
//------------------------------------------------------------------------------

// Moves the centre of the x/y bounding box of the structure to the origin
// (z is not translated) and rotates the structure about the z axis.
//
// The rotation angle is fi = -(alpha+beta)/2, where:
//   alpha = atan(y/x) of the atom with the largest x
//   beta  = atan(y/x) of the atom with the largest y
// Both are measured before the translation. Each atom is translated first and
// then transformed by the matrix built from fi.
// An empty molecule is left untouched.
//
// NOTE(review): atan(y/x) gives NaN when an extreme atom sits at x == y == 0.
void AlignStructure(OBMol &mol)
{
    cout << "   Aligning structure ..." << endl;

    // nothing to align; the extremes below would otherwise stay uninitialised
    if( mol.NumAtoms() == 0 ) return;

    // (x,y) of the atoms with the largest x, largest y, smallest x and
    // smallest y coordinate; index [0] holds x, index [1] holds y
    double byMaxXcoord [2];
    double byMaxYcoord [2];
    double byMinXcoord [2];
    double byMinYcoord [2];
    double x, y;

    for (unsigned int i=1; i<=mol.NumAtoms();i++){
        x = mol.GetAtom(i)->GetVector().GetX();
        y = mol.GetAtom(i)->GetVector().GetY();

        // babel indexes atoms from 1, so atom 1 seeds all extremes
        // (the former test i==0 never fired and left the arrays uninitialised)
        if(i==1){
            byMaxXcoord[0] = x; byMaxXcoord[1] = y;
            byMaxYcoord[0] = x; byMaxYcoord[1] = y;
            byMinXcoord[0] = x; byMinXcoord[1] = y;
            byMinYcoord[0] = x; byMinYcoord[1] = y;
        }
        else
        {
            if(byMaxXcoord[0] < x){ byMaxXcoord[0] = x; byMaxXcoord[1] = y; }
            if(byMaxYcoord[1] < y){ byMaxYcoord[0] = x; byMaxYcoord[1] = y; }
            if(byMinXcoord[0] > x){ byMinXcoord[0] = x; byMinXcoord[1] = y; }
            if(byMinYcoord[1] > y){ byMinYcoord[0] = x; byMinYcoord[1] = y; }
        }
    }

    // rotation about the z axis
    double alpha = atan(byMaxXcoord[1]/byMaxXcoord[0]);
    double beta  = atan(byMaxYcoord[1]/byMaxYcoord[0]);
    double fi = -(alpha+beta)/2;
    double M [3][3];
    M[0][0] = cos(fi); M[0][1] = sin(fi); M[0][2] = 0;
    M[1][0] =-sin(fi); M[1][1] = cos(fi); M[1][2] = 0;
    M[2][0] = 0;       M[2][1] = 0;       M[2][2] = 1;

    // translation of the x/y bounding-box centre to the origin; z is assumed correct
    vector3 V(-(byMaxXcoord[0]+byMinXcoord[0])/2,-(byMaxYcoord[1]+byMinYcoord[1])/2,0);
    vector3 newV;

    cout << "   Rotating structure by "  << fi << " rad." << endl;

    // write new coordinates to each atom
    for (unsigned int i=1; i<=mol.NumAtoms();i++){
        OBAtom* p_atom = mol.GetAtom(i);

        newV = p_atom->GetVector();
        newV+= V;
        newV*= M;
        p_atom->SetVector(newV);
    }
}


//------------------------------------------------------------------------------

// Writes mol in xyz format to the file 'name' using the standard OpenBabel
// XYZ writer.
// Returns false, with a message on cerr, if any of these fails:
//   - opening the file
//   - selecting the XYZ format
//   - writing the molecule
bool WriteStructure(OBMol& mol, const string& name)
{
    ofstream ofs(name.c_str());
    if( ! ofs ){
        cerr << ">>> ERROR: Unable to open output file : '" << name << "'!" << endl;
        return(false);
    }

    // setup standard converter; SetOutFormat fails if the format is unknown
    OBConversion conv(NULL,&ofs);
    if( conv.SetOutFormat("XYZ") == false ){
        std::cerr << ">>> ERROR: XYZ output format is not available!" << std::endl;
        return(false);
    }

    // save structure
    if( conv.Write(&mol) == false ){
        cerr << ">>> ERROR: Unable to save output file : '" << name << "'!" << endl;
        return(false);
    }

    return(true);
}

// Returns the number of oxygen atoms in mol that have exactly one bonded
// neighbour (of any element).
int CountFreeOxygens(OBMol& mol)
{
    int result = 0;
    const unsigned int atomCount = mol.NumAtoms();

    for (unsigned int idx = 1; idx <= atomCount; ++idx)
    {
        OBAtom* oxygen = mol.GetAtom(idx);
        if (oxygen->GetAtomicNum() == 8)
        {
            int neighbours = 0;
            OBBondIterator it;
            for (OBAtom* nbr = oxygen->BeginNbrAtom(it); nbr != NULL; nbr = oxygen->NextNbrAtom(it))
                ++neighbours;

            if (neighbours == 1)
                ++result;
        }
    }

    return result;
}

// Returns the number of oxygen (Z = 8) atoms in mol.
int CountOxygens(OBMol& mol)
{
    int count = 0;
    const unsigned int atomCount = mol.NumAtoms();

    for (unsigned int idx = 1; idx <= atomCount; ++idx)
    {
        if (mol.GetAtom(idx)->GetAtomicNum() == 8)
            ++count;
    }

    return count;
}

// Writes a tleap input script to the file 'name'. The script:
//   - sources leaprc.ff99SB and leaprc.silica
//   - creates unit U with residues SRF (bulk) and PRB (probe)
//   - adds every probe atom, then every bulk atom shifted by +120 along z
//   - adds all bonds of both residues
//   - runs "solvateBox U TIP3PBOX 10.0"
//   - saves hrot.parm7 / hrot.rst7
//
// Side effect: the coordinates of bulkmol are modified in place (z += 120).
// Returns false, with a message on cerr, if the file cannot be opened or
// writing/closing it fails; true otherwise.
bool BuildTLeap(OBMol& probemol, OBMol& bulkmol, const string& name)
{
    // z shift applied to every bulk atom (presumably to keep bulk and probe apart)
    const double bulkZShift = 120.0;

    FILE* file = fopen(name.c_str(), "w");
    if( file == NULL ){
        std::cerr << ">>> ERROR: Unable to open tleap file : '" << name << "'!" << std::endl;
        return(false);
    }

    // header - sources, unit, residues
    cout << "Adding head..." << endl;
    fprintf(file, "source leaprc.ff99SB\n");
    fprintf(file, "source leaprc.silica\n");
    fprintf(file, "U = createUnit U\n");
    fprintf(file, "bulk = createResidue SRF\n");
    fprintf(file, "probe = createResidue PRB\n");

    // probe atoms
    cout << "Adding probe atoms..." << endl;
    for (unsigned int i=1;i<=probemol.NumAtoms();i++)
    {
        CreateAtom(file, probemol.GetAtom(i));
        fprintf(file, "add probe newAtom\n");
    }

    // bulk atoms, shifted along z (this modifies bulkmol)
    cout << "Adding bulk atoms..." << endl;
    for (unsigned int i=1;i<=bulkmol.NumAtoms();i++)
    {
        OBAtom* p_atom = bulkmol.GetAtom(i);
        p_atom->SetVector(p_atom->GetX(), p_atom->GetY(), p_atom->GetZ()+bulkZShift);

        CreateAtom(file, p_atom);
        fprintf(file, "add bulk newAtom\n");
    }

    WriteBonds(file, probemol, "probe");
    WriteBonds(file, bulkmol, "bulk");

    cout << "Adding footer..." << endl;
    fprintf(file, "add U bulk\n");
    fprintf(file, "add U probe\n");
    fprintf(file, "solvateBox U TIP3PBOX 10.0\n");
    fprintf(file, "saveAmberParm U hrot.parm7 hrot.rst7\n");
    fprintf(file, "quit\n");

    // fclose() flushes the buffer; always close, and report either a stream
    // error or a failed close instead of claiming success
    bool failed = (ferror(file) != 0);
    if( fclose(file) != 0 ) failed = true;
    if( failed ){
        std::cerr << ">>> ERROR: Unable to write tleap file : '" << name << "'!" << std::endl;
        return(false);
    }

    cout << "TLeap has been successfully built!" << endl;
    return(true);
}

void CreateAtom(FILE* file, OBAtom* p_atom)
{
    //writes atom to tleap file
    switch(GetAtomType(p_atom)) //decides about atom Type
    {
        case HY:    //hydrogen
            fprintf(file, "newAtom = createAtom H HY 0.4\n");
        break;

        case SI0:   //inner silicon
            fprintf(file, "newAtom = createAtom %s SI 1.1\n", GetSiName(p_atom).c_str());
        break;

        case SI1:   //outer silicon
            fprintf(file, "newAtom = createAtom %s SI 0.725\n", GetSiName(p_atom).c_str());
        break;

        case OX:    //inner oxygen
            fprintf(file, "newAtom = createAtom O OX -0.55\n");
        break;

        case OY0:   //outer oxygen with H-bond
            fprintf(file, "newAtom = createAtom O OY -0.675\n");
            //cout << "Detekovan OY0. Pocet kremiku: "<< CountNbrSilicons(p_atom) << " pocet vazeb: " << p_atom->GetValence() << endl;
        break;

        case OY1:   //outer oxygen without H-bond
            fprintf(file, "newAtom = createAtom O OY -0.9\n");
        break;
    }

    //writes atom position
    fprintf(file, "set newAtom position { %12.6f %12.6f %12.6f }\n", p_atom->GetX(), p_atom->GetY(), p_atom->GetZ());

}

// Returns the tleap atom name for a silicon: "SiX" when IsFixedAtom() reports
// the atom as fixed (called with its default arguments), "Si" otherwise.
string GetSiName(OBAtom* p_atom)
{
    const bool fixed = IsFixedAtom(p_atom);
    return fixed ? string("SiX") : string("Si");
}

// Writes one tleap "bond" command for every bond of mol, e.g.
// "bond probe.3 probe.7". Atoms are addressed as <resName>.<OpenBabel atom
// index>. The number of written bonds is reported on cout.
void WriteBonds(FILE* file, OBMol& mol, const string& resName)
{
    cout << "Adding bonds..." << endl;

    int bCount = 0;
    OBBondIterator bi;
    for (OBBond* bond = mol.BeginBond(bi); bond != NULL; bond = mol.NextBond(bi))
    {
        // GetIdx() returns an unsigned value, so print it with %u (not %i)
        fprintf(file, "bond %s.%u %s.%u\n",
                resName.c_str(), bond->GetBeginAtom()->GetIdx(),
                resName.c_str(), bond->GetEndAtom()->GetIdx());
        bCount++;
    }

    cout << "Added " << bCount << " bonds" << endl;
}

// Returns how many of p_atom's bonded neighbours are silicon (Z = 14).
int CountNbrSilicons(OBAtom* p_atom)
{
    int silicons = 0;

    OBBondIterator it;
    for (OBAtom* nbr = p_atom->BeginNbrAtom(it); nbr != NULL; nbr = p_atom->NextNbrAtom(it))
    {
        if (nbr->GetAtomicNum() == 14)
            ++silicons;
    }

    return silicons;
}

// Returns true for an inner silicon (SI0), i.e. one without any bonded oxygen
// of type OY1; returns false as soon as such an oxygen is found.
bool isSi0(OBAtom* p_atom)
{
    OBBondIterator it;
    for (OBAtom* nbr = p_atom->BeginNbrAtom(it); nbr != NULL; nbr = p_atom->NextNbrAtom(it))
    {
        // test the element first so GetAtomType() is only asked about oxygens
        if (nbr->GetAtomicNum() == 8 && GetAtomType(nbr) == OY1)
            return false;
    }

    return true;
}

// Classifies an atom for the tleap output:
//   H                               -> HY
//   O bonded to exactly two silicons -> OX  (inner oxygen)
//   other O with exactly one bond    -> OY1 (outer oxygen without hydrogen)
//   any other O                      -> OY0 (outer oxygen with hydrogen)
//   Si without an OY1 neighbour      -> SI0 (inner silicon)
//   other Si                         -> SI1 (outer silicon)
// Any other element is unsupported. An error is printed on cerr and
// std::runtime_error is thrown; the function never falls off its end.
atomType GetAtomType(OBAtom* p_atom)
{
    const int z = p_atom->GetAtomicNum();

    if (z == 1) // hydrogen
        return HY;

    if (z == 8)
    {
        if (CountNbrSilicons(p_atom) == 2)
            return OX;  // inner oxygen
        if (p_atom->GetValence() == 1)
            return OY1; // outer oxygen without hydrogen
        return OY0;     // outer oxygen with hydrogen
    }

    if (z == 14)
        return isSi0(p_atom) ? SI0 : SI1; // inner / outer silicon

    std::cerr << ">>> ERROR: Unsupported element (Z = " << z << ") for atom "
              << p_atom->GetIdx() << "!" << std::endl;
    throw std::runtime_error("GetAtomType: unsupported element");
}

// Adds a hydrogen to every singly-bonded outer oxygen (type OY1) of mol.
// The hydrogen is placed hDistance (1.0) from the oxygen, on the extension of
// the neighbour->oxygen bond, and joined to the oxygen by a single bond.
// Only atoms present on entry are examined, so the new hydrogens are never
// classified themselves.
void AddHydrogens(OBMol& mol)
{
    const double hDistance = 1.0;                  // distance of hydrogen from oxygen
    const unsigned int currAtoms = mol.NumAtoms(); // atoms present before any H is added
    int hAdded = 0;
    cout << "Adding hydrogens...." << endl;

    for (unsigned int i=1; i<=currAtoms; i++)
    {
        OBAtom* p_atom = mol.GetAtom(i);

        // only oxygens can be OY1; skipping the rest avoids classifying atoms needlessly
        if (p_atom->GetAtomicNum() != 8) continue;
        if (GetAtomType(p_atom) != OY1) continue;

        // OY1 has exactly one bond, so the first neighbour is its only one
        OBBondIterator bi;
        OBAtom* nbr_atom = p_atom->BeginNbrAtom(bi);

        vector3 hDirection = p_atom->GetVector() - nbr_atom->GetVector();
        hDirection.normalize(); // direction of the new hydrogen

        OBAtom* h_atom = mol.NewAtom();
        h_atom->SetAtomicNum(1);                                       // hydrogen
        h_atom->SetVector(p_atom->GetVector() + hDistance*hDirection); // position

        // O-H is a single bond (bond order 1)
        mol.AddBond(p_atom->GetIdx(), h_atom->GetIdx(), 1);
        hAdded++;
    }

    cout << "Added " << hAdded << " hydrogens to residue..." << endl;
}

// Tells whether an atom belongs to the fixed layer.
//   tolerance != 0 : geometric test, true when |z - refZ| <= tolerance
//   tolerance == 0 : true when the atom already carries the "fixed" flag,
//                    i.e. its isotope field is 1 (set by InitFixedAtoms)
bool IsFixedAtom(OBAtom* p_atom, double refZ, double tolerance)
{
    if (tolerance != 0)
    {
        // std::fabs: an unqualified abs() may resolve to the integer overload
        // and truncate the distance
        return std::fabs(p_atom->GetZ() - refZ) <= tolerance;
    }

    return p_atom->GetIsotope() == 1;
}


// Marks the fixed layer of mol.
// GetFixedZValue() supplies the reference z for the requested end (TOP or
// BOTTOM). Every atom, of any element, for which IsFixedAtom(atom, refZ,
// tolerance) holds gets isotope 1, which serves as the "fixed" flag. With
// tolerance == 0, IsFixedAtom only re-checks that flag, so no new atoms are
// marked. The number of flagged atoms is printed.
void InitFixedAtoms(OBMol& mol, fixType fType, double tolerance)
{
    const double referenceZ = GetFixedZValue(mol, fType);
    int marked = 0;

    for (unsigned int idx = 1; idx <= mol.NumAtoms(); ++idx)
    {
        OBAtom* atom = mol.GetAtom(idx);
        if (!IsFixedAtom(atom, referenceZ, tolerance))
            continue;

        atom->SetIsotope(1);
        ++marked;
    }

    cout << "Initialized " << marked << " fixed atoms." << endl;
}

// Returns the z coordinate of the extreme silicon of mol: the largest silicon
// z for TOP, the smallest for BOTTOM. InitFixedAtoms uses it as the reference
// plane.
// Fallbacks: if mol contains no silicon, the z of atom 1 is returned; an empty
// molecule yields 0.
double GetFixedZValue(OBMol& mol, fixType fType)
{
    if (mol.NumAtoms() == 0) return(0.0); // GetAtom(1) would be NULL

    bool found = false;
    double fixZ = 0;

    for (unsigned int i=1; i<= mol.NumAtoms(); i++)
    {
        OBAtom* p_atom = mol.GetAtom(i);
        if (p_atom->GetAtomicNum() != 14) continue; // only silicons

        const double z = p_atom->GetZ();

        // seed from the first silicon, not from atom 1 (which may be an
        // oxygen lying beyond every silicon)
        if (!found)
        {
            fixZ = z;
            found = true;
            continue;
        }

        switch(fType){
            case TOP:
                if(z > fixZ) fixZ = z;
            break;

            case BOTTOM:
                if(z < fixZ) fixZ = z;
            break;
        }
    }

    // no silicon at all: fall back to the z of atom 1
    if (!found) fixZ = mol.GetAtom(1)->GetZ();

    return(fixZ);
}


//------------------------------------------------------------------------------
